import matplotlib.pyplot as plt
import numpy as np
import os 
import pickle
import torch.nn as nn
import torch
import torch.nn.functional as F

from utils_others import pdist

# All experiment inputs and plot outputs are resolved against the directory
# the script is launched from.
top_dir = os.getcwd()

# Experiment name -> path (relative to top_dir) of its pickled tracking result.
# The name is reused in the printed statistics and in the output file names.
# Uncomment an entry to include that run.
read_path_dict = {

    # 'CIFAR20_SEQ':'experiments/SEQ_CIFAR100/2023-04-23-23-07-41/tracking_result',
    # 'CIFAR20_REPLAY':'experiments/REPLAY_CIFAR100/2023-04-23-23-07-43/tracking_result',
    # 'CIFAR20_MTL':'experiments/MTL_CIFAR100/2023-04-24-15-32-30/tracking_result',
    # 'CIFAR20_RESNET_SEQ':'experiments/SEQ_CIFAR100/2023-04-24-10-27-35/tracking_result',
    # 'CIFAR20_RESNET_REPLAY':'experiments/REPLAY_CIFAR100/2023-04-24-10-08-55/tracking_result',
    'CIFAR20_RESNET_MTL':'experiments/MTL_CIFAR100/2023-04-25-12-09-57/tracking_result', 
}

# Turn the relative paths into absolute ones.
read_path_dict = {k:os.path.join(top_dir,v) for k,v in read_path_dict.items()}

# Output directory for the figures. exist_ok=True creates it when missing and
# is a no-op when it is already there, without a separate (race-prone)
# existence check.
save_path = os.path.join(top_dir,'plots/tracking_result/')
os.makedirs(save_path, exist_ok=True)

# Number of tasks in the tracked run; also the number of colours drawn from
# the 'tab20' colormap (one curve colour per task).
total_task = 20
# color_list = ['r','b','g','y','m']
color_list = plt.get_cmap('tab20',total_task)

# For every experiment: load the tracked class centers, plot one
# classifier-vs-encoder distance curve per task, and print summary statistics.
for k,v in read_path_dict.items():

    # NOTE(security): pickle.load can execute arbitrary code while loading.
    # Only point read_path_dict at tracking results you produced yourself.
    with open(v,'rb') as f:
        load_dict = pickle.load(f)
    # Both dicts are indexed by task id; each entry is a sequence of center
    # matrices (one per tracked snapshot, rows = classes).
    # presumably the snapshots are recorded while that task is trained - TODO confirm
    cls_center, encoder_center = load_dict['cls_center_dict'], load_dict['encoder_center_dict']

    # Lay the snapshots of all tasks end to end on one shared x axis.
    #   init_task_x[t]         : x position of the first snapshot of task t
    #                            (init_task_x[-1] is the total snapshot count)
    #   end_task_x[t]          : x position of the last snapshot of task t
    #   num_class_task_list[t] : row offset where the classes of task t start,
    #                            taken from the row count of each task's first matrix
    x_cnt = 0
    init_task_x = [0]
    num_class_task_list = [0]
    end_task_x = []
    for task_id in range(total_task):
        cls_center_mat_init, enc_center_mat_init = cls_center[task_id][0], encoder_center[task_id][0]
        num_seen_class = cls_center_mat_init.shape[0]
        assert enc_center_mat_init.shape[0] == num_seen_class
        num_class_task_list.append(num_seen_class)
        x_cnt+=len(cls_center[task_id])
        init_task_x.append(x_cnt)
        end_task_x.append(x_cnt-1)

    # =============================== Cls-Enc Distance ==========================================
    cosine_sim = nn.CosineSimilarity(dim=-1)
    new_dist_list = []                                # task t measured at the end of task t
    old_dist_list = [[] for _ in range(total_task)]   # old_dist_list[t]: every earlier task measured at the end of task t
    last_dist_list = []                               # task t measured at the very last snapshot
    for task_id in range(total_task):
        # Rows of the center matrices that belong to the classes of this task.
        row_start, row_stop = num_class_task_list[task_id], num_class_task_list[task_id+1]

        # Distance curve of this task, from its own first snapshot to the end of the run.
        cls_encoder_dist = []
        for j in range(task_id,total_task):
            for cls_center_mat, enc_center_mat in zip(cls_center[j], encoder_center[j]):
                task_enc_center = F.normalize(enc_center_mat[row_start:row_stop],p=2,dim=-1)
                task_cls_center = F.normalize(cls_center_mat[row_start:row_stop],p=2,dim=-1)
                # _dist = nn.PairwiseDistance(p=2)(task_enc_center,task_cls_center) # Euclidean Distance
                _dist = 1-cosine_sim(task_enc_center,task_cls_center) # Cosine Distance
                # _dist = torch.arccos(_dist)/torch.pi*180
                # _dist = (pdist(task_enc_center)-pdist(task_cls_center)).abs().mean() # Struct Distance
                # Average over the classes of the task, ignoring NaN entries.
                _dist = torch.mean(_dist[~torch.isnan(_dist)]).item()
                assert not np.isnan(_dist)

                cls_encoder_dist.append(_dist)

        # cls_encoder_dist starts at the first snapshot of task_id, so a global
        # x position is converted to an index into it by subtracting
        # init_task_x[task_id]. This stays correct when tasks recorded
        # different numbers of snapshots.
        x_offset = init_task_x[task_id]
        last_dist_list.append(cls_encoder_dist[-1])
        new_dist_list.append(cls_encoder_dist[end_task_x[task_id]-x_offset])
        for t_id in range(task_id+1,total_task):
            old_dist_list[t_id].append(cls_encoder_dist[end_task_x[t_id]-x_offset])

        plt.plot(list(range(x_offset,init_task_x[-1])),cls_encoder_dist,color=color_list(task_id),linestyle='-',linewidth=4)

    # The first task is skipped in the "aver" statistics: old_dist_list[0] is
    # always empty because no task precedes it.
    print('%s: aver_old_mean = %.4f'%(k,np.mean([np.mean(lst) for lst in old_dist_list[1:]])))
    print('%s: aver_new_mean = %.4f'%(k,np.mean(new_dist_list[1:])))
    print('%s: aver_all_mean = %.4f'%(k,np.mean([np.mean(a+[b]) for a,b in zip(old_dist_list[1:],new_dist_list[1:])])))

    # Statistics at the final snapshot: every task but the last, the last task, all tasks.
    print('%s: last_old_mean = %.4f'%(k,np.mean(last_dist_list[:-1])))
    print('%s: last_new_mean = %.4f'%(k,last_dist_list[-1]))
    print('%s: last_all_mean = %.4f'%(k,np.mean(last_dist_list)))

    # One x tick at the first snapshot of each task, labelled with the 1-based task number.
    plt.yticks(fontsize=12)
    plt.xticks(ticks=init_task_x[:-1],labels=['%d'%(i+1) for i in range(total_task)],fontsize=12)
    plt.ylabel('Feature-Embedding Distance',fontsize=15)
    plt.xlabel('Task',fontsize=15)
    plt.ylim(0,1.6)
    # if k in['CIFAR20_MTL','CIFAR20_RESNET_MTL']:
    #     plt.legend(['Task %d'%(i+1) for i in range(20)],bbox_to_anchor=(1.0,0.7), ncol=2)
    plt.savefig(os.path.join(save_path,'%s_cls_enc.pdf'%(k)),dpi=1200,bbox_inches='tight')
    plt.savefig(os.path.join(save_path,'%s_cls_enc.png'%(k)),dpi=1200,bbox_inches='tight')
    # Clear the figure so the next experiment starts from an empty canvas.
    plt.clf()

    # # ================================= Enc Distance ==========================================
    # for task_id in range(total_task):
    #     for j in range(task_id,total_task):
    #         cls_encoder_dist = []
    #         task_enc_center_init = F.normalize(encoder_center[j][0][num_class_task_list[task_id]:num_class_task_list[task_id+1]],p=2,dim=-1)
    #         for cls_center_mat, enc_center_mat in zip(cls_center[j], encoder_center[j]):
    #             task_enc_center = F.normalize(enc_center_mat[num_class_task_list[task_id]:num_class_task_list[task_id+1]],p=2,dim=-1)
    #             task_cls_center = F.normalize(cls_center_mat[num_class_task_list[task_id]:num_class_task_list[task_id+1]],p=2,dim=-1)
    #             # _dist = nn.PairwiseDistance(p=2)(task_enc_center,task_cls_center) # Euclidean Distance
    #             _dist = 1-nn.CosineSimilarity(dim=-1)(task_enc_center,task_enc_center_init) # Cosine Distance
    #             # _dist = torch.arccos(_dist)/torch.pi*180
    #             # _dist = (pdist(task_enc_center)-pdist(task_enc_center_init)).abs().mean() # Struct Distance
    #             _dist = torch.mean(_dist[~torch.isnan(_dist)]).item()
    #             assert not np.isnan(_dist)
    #             cls_encoder_dist.append(_dist)
    #         plt.plot(list(range(init_task_x[j],init_task_x[j+1])),cls_encoder_dist,color=color_list(task_id),linestyle='-')
    
    # plt.xticks(ticks=[x for x in init_task_x[:-1]],labels=['%d'%(i+1) for i in range(total_task)])
    # plt.ylabel('Enc Distance')
    # plt.xlabel('Task')
    # plt.ylim(0,0.75)
    # plt.savefig(os.path.join(save_path,'%s_enc.png'%(k)),dpi=600)
    # plt.clf()

    # # ================================= Cls Distance ==========================================
    # for task_id in range(total_task):
        
    #     for j in range(task_id,total_task):
    #         cls_encoder_dist = []
    #         task_cls_center_init = F.normalize(cls_center[j][0][num_class_task_list[task_id]:num_class_task_list[task_id+1]],p=2,dim=-1)
    #         for cls_center_mat, enc_center_mat in zip(cls_center[j], encoder_center[j]):
    #             task_enc_center = F.normalize(enc_center_mat[num_class_task_list[task_id]:num_class_task_list[task_id+1]],p=2,dim=-1)
    #             task_cls_center = F.normalize(cls_center_mat[num_class_task_list[task_id]:num_class_task_list[task_id+1]],p=2,dim=-1)
    #             # _dist = nn.PairwiseDistance(p=2)(task_enc_center,task_cls_center) # Euclidean Distance
    #             _dist = 1-nn.CosineSimilarity(dim=-1)(task_cls_center,task_cls_center_init) # Cosine Distance
    #             # _dist = torch.arccos(_dist)/torch.pi*180
    #             _dist = torch.mean(_dist[~torch.isnan(_dist)]).item()
    #             assert not np.isnan(_dist)
    #             cls_encoder_dist.append(_dist)
    #         plt.plot(list(range(init_task_x[j],init_task_x[j+1])),cls_encoder_dist,color=color_list(task_id),linestyle='-')
    
    # plt.xticks(ticks=[x for x in init_task_x[:-1]],labels=['%d'%(i+1) for i in range(total_task)])
    # plt.ylabel('Cls Distance')
    # plt.xlabel('Task')
    # plt.ylim(0,0.75)
    # plt.savefig(os.path.join(save_path,'%s_cls.png'%(k)),dpi=600)
    # plt.clf()
        
